import numpy as np
import torch
import logging
import threading
import time
import wave
from typing import List, Optional
from queue import SimpleQueue, Empty

from whisperlivekit.timed_objects import SpeakerSegment

logger = logging.getLogger(__name__)

try:
    from nemo.collections.asr.models import SortformerEncLabelModel
    from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
except ImportError:
    raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")


class StreamingSortformerState:
    """
    Mutable state container for the streaming Sortformer model.

    Attributes:
        spkcache (torch.Tensor): Speaker cache storing embeddings from the stream start
        spkcache_lengths (torch.Tensor): Lengths of the speaker cache
        spkcache_preds (torch.Tensor): Speaker predictions for the speaker cache parts
        fifo (torch.Tensor): FIFO queue holding embeddings from the latest chunks
        fifo_lengths (torch.Tensor): Lengths of the FIFO queue
        fifo_preds (torch.Tensor): Speaker predictions for the FIFO queue parts
        spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
        mean_sil_emb (torch.Tensor): Mean silence embedding
        n_sil_frames (torch.Tensor): Number of silence frames
    """

    def __init__(self):
        # All fields start unset; the owner fills them with tensors later.
        # Speaker-cache state (embeddings accumulated since stream start).
        self.spkcache = self.spkcache_lengths = self.spkcache_preds = None
        # FIFO-queue state (embeddings from the most recent chunks).
        self.fifo = self.fifo_lengths = self.fifo_preds = None
        # Auxiliary state: permutation info, mean silence embedding, silence count.
        self.spk_perm = self.mean_sil_emb = self.n_sil_frames = None

class SortformerDiarization:
    def __init__(self, model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"):
        """
        Holds the shared streaming Sortformer diarization model.

        A single instance is reused whenever a new online diarizer is created.

        Args:
            model_name: Pre-trained model identifier to download.
        """
        self._load_model(model_name)

    def _load_model(self, model_name: str):
        """Download the pre-trained Sortformer model and configure it for streaming."""
        try:
            self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
            self.diar_model.eval()

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.diar_model.to(device)

            logger.info(f"Using {device.type.upper()} for Sortformer model")

            # Streaming configuration; lengths are in (subsampled) frames.
            # NOTE(review): values presumably tuned for this checkpoint — confirm
            # against the NeMo streaming Sortformer documentation before changing.
            streaming_settings = {
                "chunk_len": 10,
                "subsampling_factor": 10,
                "chunk_right_context": 0,
                "chunk_left_context": 10,
                "spkcache_len": 188,
                "fifo_len": 188,
                "spkcache_update_period": 144,
                "log": False,
            }
            for attr, value in streaming_settings.items():
                setattr(self.diar_model.sortformer_modules, attr, value)
            self.diar_model.sortformer_modules._check_streaming_parameters()

        except Exception as e:
            logger.error(f"Failed to load Sortformer model: {e}")
            raise
 
class SortformerDiarizationOnline:
    def __init__(self, shared_model, sample_rate: int = 16000):
        """
        Initialize the streaming Sortformer diarization system.
        
        Args:
            sample_rate: Audio sample rate (default: 16000)
            model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
        """
        self.sample_rate = sample_rate
        self.diarization_segments = []
        self.diar_segments = []
        self.buffer_audio = np.array([], dtype=np.float32)
        self.segment_lock = threading.Lock()
        self.global_time_offset = 0.0
        self.debug = False
                
        self.diar_model = shared_model.diar_model
             
        self.audio2mel = AudioToMelSpectrogramPreprocessor(
            window_size=0.025,
            normalize="NA",
            n_fft=512,
            features=128,
            pad_to=0
        )
        self.audio2mel.to(self.diar_model.device)
        
        self.chunk_duration_seconds = (
            self.diar_model.sortformer_modules.chunk_len * 
            self.diar_model.sortformer_modules.subsampling_factor * 
            self.diar_model.preprocessor._cfg.window_stride
        )
        
        self._init_streaming_state()
        
        self._previous_chunk_features = None
        self._chunk_index = 0
        self._len_prediction = None
        
        # Audio buffer to store PCM chunks for debugging
        self.audio_buffer = []
        
        # Buffer for accumulating audio chunks until reaching chunk_duration_seconds
        self.audio_chunk_buffer = []
        self.accumulated_duration = 0.0
        
        logger.info("SortformerDiarization initialized successfully")


    def _init_streaming_state(self):
        """Initialize the streaming state for the model."""
        batch_size = 1
        device = self.diar_model.device
        
        self.streaming_state = StreamingSortformerState()
        self.streaming_state.spkcache = torch.zeros(
            (batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model), 
            device=device
        )
        self.streaming_state.spkcache_preds = torch.zeros(
            (batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk), 
            device=device
        )
        self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
        self.streaming_state.fifo = torch.zeros(
            (batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model), 
            device=device
        )
        self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
        self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
        self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)        
        self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)

    def insert_silence(self, silence_duration: Optional[float]):
        """
        Insert silence period by adjusting the global time offset.
        
        Args:
            silence_duration: Duration of silence in seconds
        """
        with self.segment_lock:
            self.global_time_offset += silence_duration
        logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")

    def insert_audio_chunk(self, pcm_array: np.ndarray):
        if self.debug:
            self.audio_buffer.append(pcm_array.copy())
        self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
  

    async def diarize(self):
        """
        Process audio data for diarization in streaming fashion.
        
        Args:
            pcm_array: Audio data as numpy array
        """

        threshold = int(self.chunk_duration_seconds * self.sample_rate)
        
        if not len(self.buffer_audio) >= threshold:
            return []
        
        audio = self.buffer_audio[:threshold]
        self.buffer_audio = self.buffer_audio[threshold:]
        
        device = self.diar_model.device
        audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
        audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
        
        processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
            audio_signal_chunk, audio_signal_length_chunk
        )
        processed_signal_chunk = processed_signal_chunk.to(device)
        processed_signal_length_chunk = processed_signal_length_chunk.to(device)
        
        if self._previous_chunk_features is not None:
            to_add = self._previous_chunk_features[:, :, -99:].to(device)
            total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
        else:
            total_features = processed_signal_chunk.to(device)
        
        self._previous_chunk_features = processed_signal_chunk.to(device)
        
        chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
        
        with torch.inference_mode():
            left_offset = 8 if self._chunk_index > 0 else 0
            right_offset = 8
            
            self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
                processed_signal=chunk_feat_seq_t,
                processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
                streaming_state=self.streaming_state,
                total_preds=self.total_preds,
                left_offset=left_offset,
                right_offset=right_offset,
            )                
        new_segments = self._process_predictions()
        
        self._chunk_index += 1
        return new_segments

    def _process_predictions(self):
        """Process model predictions and convert to speaker segments."""
        preds_np = self.total_preds[0].cpu().numpy()
        active_speakers = np.argmax(preds_np, axis=1)
        
        if self._len_prediction is None:
            self._len_prediction = len(active_speakers) #12
        
        frame_duration = self.chunk_duration_seconds / self._len_prediction
        current_chunk_preds = active_speakers[-self._len_prediction:]
        
        new_segments = []

        with self.segment_lock:
            base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
            current_spk = current_chunk_preds[0]
            start_time = round(base_time, 2)
            for idx, spk in enumerate(current_chunk_preds):
                current_time = round(base_time + idx * frame_duration, 2)
                if spk != current_spk:
                    new_segments.append(SpeakerSegment(
                        speaker=current_spk,
                        start=start_time,
                        end=current_time
                    ))
                    start_time = current_time
                    current_spk = spk
            new_segments.append(
                SpeakerSegment(
                        speaker=current_spk,
                        start=start_time,
                        end=current_time
            )
            )
        return new_segments
                
    def get_segments(self) -> List[SpeakerSegment]:
        """Get a copy of the current speaker segments."""
        with self.segment_lock:
            return self.diarization_segments.copy()

    def close(self):
        """Close the diarization system and clean up resources."""
        logger.info("Closing SortformerDiarization")
        with self.segment_lock:
            self.diarization_segments.clear()
        
        if self.debug:
            concatenated_audio = np.concatenate(self.audio_buffer)
            audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)                
            with wave.open("diarization_audio.wav", "wb") as wav_file:
                wav_file.setnchannels(1)  # mono audio
                wav_file.setsampwidth(2)   # 2 bytes per sample (int16)
                wav_file.setframerate(self.sample_rate)
                wav_file.writeframes(audio_data_int16.tobytes())
            logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")


def extract_number(s: str) -> int:
    """Return the first run of digits in ``s`` as an int, or 0 when there is none
    (compatibility function)."""
    import re
    runs = re.findall(r'\d+', s)
    return int(runs[0]) if runs else 0


if __name__ == '__main__':
    import asyncio
    import librosa

    async def main():
        """TEST ONLY: run streaming diarization over a local wav file."""
        an4_audio = 'diarization_audio.wav'
        signal, sr = librosa.load(an4_audio, sr=16000)
        signal = signal[:16000*30]  # keep only the first 30 seconds

        print("\n" + "=" * 50)
        print("ground truth:")
        print("Speaker 0: 0:00 - 0:09")
        print("Speaker 1: 0:09 - 0:19") 
        print("Speaker 2: 0:19 - 0:25")
        print("Speaker 0: 0:25 - 0:30")
        print("=" * 50)

        diarization_backend = SortformerDiarization()
        diarization = SortformerDiarizationOnline(shared_model = diarization_backend)
        chunk_size = 1600  # 100 ms at 16 kHz

        for i in range(0, len(signal), chunk_size):
            chunk = signal[i:i+chunk_size]
            # Fix: diarize() takes no arguments — audio must be fed through
            # insert_audio_chunk() before stepping the diarizer.
            diarization.insert_audio_chunk(chunk)
            new_segments = await diarization.diarize()
            print(f"Processed chunk {i // chunk_size + 1}")
            print(new_segments)

        segments = diarization.get_segments()
        print("\nDiarization results:")
        for segment in segments:
            print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")

    asyncio.run(main())
